/*
 * Copyright (c) 2022. China Mobile (SuZhou) Software Technology Co.,Ltd. All rights reserved.
 * Lakehouse is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

package com.chinamobile.cmss.lakehouse.service.sqlconsole.impl;

import com.chinamobile.cmss.lakehouse.common.Constants;
import com.chinamobile.cmss.lakehouse.common.dto.ExecuteSQLBean;
import com.chinamobile.cmss.lakehouse.common.dto.SqlConsoleSQLResult;
import com.chinamobile.cmss.lakehouse.common.dto.SqlConsoleSQLResult.SQLResult;
import com.chinamobile.cmss.lakehouse.common.dto.engine.LakehouseResponse;
import com.chinamobile.cmss.lakehouse.common.dto.engine.SQLTaskKillReq;
import com.chinamobile.cmss.lakehouse.common.dto.sqlconsole.LakehouseDBRequest;
import com.chinamobile.cmss.lakehouse.common.enums.EngineType;
import com.chinamobile.cmss.lakehouse.common.enums.HttpStatus;
import com.chinamobile.cmss.lakehouse.common.enums.SQLTypeEnum;
import com.chinamobile.cmss.lakehouse.common.enums.Status;
import com.chinamobile.cmss.lakehouse.common.enums.TaskStatusTypeEnum;
import com.chinamobile.cmss.lakehouse.common.exception.BaseException;
import com.chinamobile.cmss.lakehouse.common.utils.ClusterUtil;
import com.chinamobile.cmss.lakehouse.common.utils.PropertyUtils;
import com.chinamobile.cmss.lakehouse.common.utils.ServiceException;
import com.chinamobile.cmss.lakehouse.common.utils.SqlAnalyzer;
import com.chinamobile.cmss.lakehouse.common.utils.SqlSplitter;
import com.chinamobile.cmss.lakehouse.common.utils.TaskUtil;
import com.chinamobile.cmss.lakehouse.core.config.KubernetesConfiguration;
import com.chinamobile.cmss.lakehouse.core.handler.K8sUriHandler;
import com.chinamobile.cmss.lakehouse.dao.HiveMetastoreConfigDao;
import com.chinamobile.cmss.lakehouse.dao.SparkSQLTaskInfoDao;
import com.chinamobile.cmss.lakehouse.dao.entity.EngineTaskInfoEntity;
import com.chinamobile.cmss.lakehouse.service.db.DbOperationService;
import com.chinamobile.cmss.lakehouse.service.engine.EngineService;
import com.chinamobile.cmss.lakehouse.service.engine.SparkTaskService;
import com.chinamobile.cmss.lakehouse.service.redis.RedisOperateClient;
import com.chinamobile.cmss.lakehouse.service.sqlconsole.JdbcKillTaskService;
import com.chinamobile.cmss.lakehouse.service.sqlconsole.SQLTaskService;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.retry.annotation.Backoff;
import org.springframework.retry.annotation.Retryable;
import org.springframework.stereotype.Service;

@Slf4j
@Service
public class SQLTaskServiceImpl implements SQLTaskService {

    // DAO for SQL-console task records (status, timings, result rows).
    @Autowired
    private SparkSQLTaskInfoDao sqlTaskDao;

    // Engine-side API client; used here to kill Spark driver pods on K8S.
    @Autowired
    private EngineService engineService;

    // Tracks live JDBC statements per user so running Hive SQL can be cancelled.
    @Autowired
    private JdbcKillTaskService jdbcKillTaskService;

    // Redis counters for the number of running tasks per instance.
    @Autowired
    private RedisOperateClient redisOperateClient;

    // Resolves cluster-external endpoints (e.g. the HiveServer2 JDBC URL).
    @Autowired
    private K8sUriHandler k8sUriHandler;

    // Creates/drops databases through the operator instead of plain SQL.
    @Autowired
    private DbOperationService dbOperationService;

    // Change-log bookkeeping for databases and tables in the metastore.
    @Autowired
    private HiveMetastoreConfigDao metastoreConfigDao;

    // Submits long-running SQL (selects/inserts/CTAS) as Spark tasks.
    @Autowired
    private SparkTaskService sparkTaskService;

    // Hive JDBC driver class name, loaded from configuration.
    private String hiveDriverName = PropertyUtils.getString(Constants.HIVE_DRIVER_NAME);

    /**
     * Upserts the status of a DDL task: creates a fresh record on first sight,
     * then stamps the given status and saves.
     *
     * @param req        SQL request carrying instance/db/user context
     * @param taskId     identifier of the task to record
     * @param taskStatus lifecycle status to persist
     */
    private void saveTaskInfo(ExecuteSQLBean req, String taskId, TaskStatusTypeEnum taskStatus) {
        EngineTaskInfoEntity entity = sqlTaskDao.findByTaskId(taskId);
        if (entity == null) {
            // No record yet for this id: build one from the request with zeroed durations.
            entity = new EngineTaskInfoEntity();
            entity.setTaskId(taskId);
            entity.setInstance(req.getInstance());
            entity.setDbName(req.getDbName());
            entity.setEngineType(req.getEngineType());
            entity.setTaskType("DDL");
            entity.setSubmitUser(req.getSubmitUser());
            entity.setSqlContent(req.getSqlContext());
            entity.setSubmitTime(new Date());
            entity.setFinishTime(new Date());
            entity.setWaitTime(0L);
            entity.setRunTime(0L);
            entity.setRetryTimes(10);
        }
        entity.setStatusTypeEnum(taskStatus);
        sqlTaskDao.save(entity);
    }

    /**
     * Executes a DDL-style statement synchronously over the supplied Hive JDBC
     * connection and records the task lifecycle in {@code sqlTaskDao}.
     *
     * <p>The live statement is registered with {@code jdbcKillTaskService} under
     * the generated task id so a concurrent kill request can cancel it, and is
     * always unregistered again in the {@code finally} block.
     *
     * @param req        SQL request (db name, sql text, submitting user, ...)
     * @param connection open Hive JDBC connection; NOT closed by this method
     * @return result holder carrying the task id plus either the fetched rows
     *         or, on failure, the error message
     */
    protected SQLResult acceptHiveSQLTask(ExecuteSQLBean req, Connection connection) {

        // Persist a RUNNING record up front so the task is visible immediately.
        EngineTaskInfoEntity hiveTaskInfo = new EngineTaskInfoEntity();
        String taskId = TaskUtil.generateTaskId();
        hiveTaskInfo.setInstance(req.getInstance());
        hiveTaskInfo.setDbName(req.getDbName());
        hiveTaskInfo.setEngineType(req.getEngineType());
        hiveTaskInfo.setTaskType("DDL");
        hiveTaskInfo.setStatusTypeEnum(TaskStatusTypeEnum.RUNNING);
        hiveTaskInfo.setSubmitTime(new Date());
        hiveTaskInfo.setTaskId(taskId);
        hiveTaskInfo.setSubmitUser(req.getSubmitUser());
        hiveTaskInfo.setSqlContent(req.getSqlContext());
        hiveTaskInfo.setSubmitUserId(req.getUserId());
        hiveTaskInfo = sqlTaskDao.save(hiveTaskInfo);

        SQLResult sqlResult = new SQLResult();
        sqlResult.setTaskID(taskId);
        Statement statement = null;
        ResultSet resultSet = null;
        String sqlContent = req.getSqlContext();
        String switchDb = "use " + req.getDbName();
        try {
            statement = connection.createStatement();
            // Register the statement so killSQLTask can cancel it by task id.
            jdbcKillTaskService.getJDBCConfiguration(req.getSubmitUser()).saveStatement(taskId,
                statement);
            // Select the target database before running the user's SQL.
            statement.execute(switchDb);

            boolean isResultSetAvailable = statement.execute(sqlContent);
            if (isResultSetAvailable) {
                resultSet = statement.getResultSet();
                int resultRows = 0;
                StringBuffer result = new StringBuffer();
                // First keyword decides the read shape below.
                String sqlToken = sqlContent.split("\\s+")[0].toLowerCase();
                while (resultSet.next()) {
                    // desc/describe yields three columns (name, type, comment);
                    // every other statement is read as a single column per row.
                    if (sqlToken.startsWith("desc") || sqlToken.startsWith("describe")) {
                        result.append(resultSet.getString(1)).append(" ").append(resultSet.getString(2)).append(" ").append(resultSet.getString(3))
                            .append("\n");
                    } else {
                        result.append(resultSet.getString(1)).append("\n");
                    }
                    resultRows++;
                }

                sqlResult.setResult(result);
                hiveTaskInfo.setResultRows(resultRows);
            } else {
                // Statement produced an update count / no result set.
                hiveTaskInfo.setResultRows(0);
            }
            hiveTaskInfo.setStatusTypeEnum(TaskStatusTypeEnum.FINISHED);
            hiveTaskInfo.setFinishTime(new Date());

            // waiting time is not considered for the hive sql
            hiveTaskInfo.setWaitTime((long) 0);
            // Run time is recorded in whole seconds since submission.
            hiveTaskInfo.setRunTime(
                ((System.currentTimeMillis() - hiveTaskInfo.getSubmitTime().getTime()) / 1000));
            sqlTaskDao.save(hiveTaskInfo);
        } catch (Throwable e) {
            // Mark the task FAILED and surface the message to the console.
            // NOTE(review): only e.getMessage() reaches the caller; the full
            // stack trace is available in the server log line above.
            log.error("Cannot run " + req.getSqlContext(), e);
            hiveTaskInfo.setRunTime(
                ((System.currentTimeMillis() - hiveTaskInfo.getSubmitTime().getTime()) / 1000));
            hiveTaskInfo.setStatusTypeEnum(TaskStatusTypeEnum.FAILED);
            hiveTaskInfo.setRetryTimes(10);
            sqlTaskDao.save(hiveTaskInfo);
            sqlResult.setMessage(e.getMessage());
            return sqlResult;
        } finally {
            // Close JDBC resources quietly and drop the cancel registration.
            if (resultSet != null) {
                try {
                    resultSet.close();
                } catch (SQLException e) {
                    /* ignored */
                }
            }
            if (statement != null) {
                try {
                    statement.close();
                } catch (SQLException e) {
                    /* ignored */
                }
            }
            jdbcKillTaskService.getJDBCConfiguration(req.getSubmitUser()).removeStatement(taskId);
        }

        // Success path only: record the table change log (failures returned above).
        updateTblInfo(req, taskId);

        return sqlResult;
    }

    /**
     * Splits the submitted SQL text into individual statements, validates each
     * with the Druid Hive parser, and routes every statement to one of three
     * executors based on its leading keyword:
     * <ul>
     *   <li>create/drop database|schema — operator via {@code dbOperationService}</li>
     *   <li>CTAS, insert/select/with/from/msck/analyze/truncate — Spark task</li>
     *   <li>other create/drop, show/alter/describe/desc — synchronous Hive JDBC</li>
     * </ul>
     *
     * <p>NOTE(review): a statement whose first token matches none of the
     * branches is silently skipped — no SQLResult is added for it; confirm
     * this is intended.
     *
     * @param req SQL request; {@code sqlContext} may contain multiple statements
     * @return per-statement results, in submission order
     * @throws ServiceException when the Hive JDBC connection cannot be opened
     * @throws BaseException    when statement processing fails unexpectedly
     */
    @Override
    public SqlConsoleSQLResult dispatchSQLTask(ExecuteSQLBean req) {
        SqlConsoleSQLResult sqlConsoleSQLResult = new SqlConsoleSQLResult();
        Connection connection = null;
        try {
            // One shared Hive connection for all statements in this request.
            Class.forName(hiveDriverName);
            connection = DriverManager.getConnection(k8sUriHandler.getExternalHiveUrl(), "hive", "hive");
        } catch (Exception e) {
            log.error("Fail to getConnection", e);
            throw new ServiceException(HttpStatus.GATEWAY_TIMEOUT.getCode(), Status.SQL_CONSOLE_ENGINE_CONN_FAILED_ERROR.getMessage());
        }

        List<SQLResult> resultList = new ArrayList<>();

        try {
            List<String> sqlList = new SqlSplitter().splitSql(req.getSqlContext());
            for (int i = 0; i < sqlList.size(); i++) {
                String taskId = TaskUtil.generateTaskId();
                SQLResult sqlResult = new SQLResult();
                String sql = sqlList.get(i).trim();
                log.info("Execute sql: " + sql);
                if (sql.length() != 0) {
                    try {
                        // parser sql
                        // Syntax check only; the parsed statement is discarded.
                        final SQLStatementParser hive = SQLParserUtils.createSQLStatementParser(sql, DbType.hive);
                        hive.parseStatement();
                    } catch (Exception e) {
                        // Record the parse failure for this statement and move on.
                        log.error(
                            String.format("Fail to parse %s sql! %s", req.getEngineType(), e.getMessage()));
                        sqlResult.setTaskID(taskId);
                        sqlResult.setMessage(
                            String.format("Fail to parse %s sql! %s", req.getEngineType(), e.getMessage()));
                        saveTaskInfo(req, taskId, TaskStatusTypeEnum.FAILED);
                        resultList.add(sqlResult);
                        continue;
                    }
                }
                // NOTE(review): validReq aliases req — setSqlContext below also
                // mutates the caller's bean on every iteration; confirm a copy
                // was not intended ("every SQL will build a request").
                ExecuteSQLBean validReq = req;
                // every SQL will build a request
                validReq.setSqlContext(sql);

                // Keyword-based routing on the lower-cased token stream.
                String[] sqlTokens = sql.toLowerCase().split("\\s+");
                if (sqlTokens[0].equals(SQLTypeEnum.CREATE.getStatus())) {
                    // create database, create schema. to operator
                    if (ArrayUtils.contains(sqlTokens, "schema")
                        || ArrayUtils.contains(sqlTokens, "database")) {
                        String dbName;
                        // "create database if not exists X" puts the name at
                        // token 5, otherwise at token 2.
                        if (ArrayUtils.contains(sqlTokens, "if")) {
                            dbName = sqlTokens[5];
                        } else {
                            dbName = sqlTokens[2];
                        }
                        LakehouseDBRequest dbRequest = LakehouseDBRequest.builder().databaseName(dbName)
                            .userIdentifier(validReq.getUserId()).userName(validReq.getSubmitUser())
                            .build();
                        saveTaskInfo(validReq, taskId, TaskStatusTypeEnum.RUNNING);
                        sqlResult.setTaskID(taskId);
                        try {
                            dbOperationService.createDatabase(dbRequest);
                            dbModifyLog(dbRequest.getUserName(), dbName, dbRequest.getUserIdentifier());
                        } catch (Exception e) {
                            // Abort the whole batch on a failed create-database.
                            sqlResult.setMessage(String.format(
                                "Error create database %s! The database already exists or is being used by another user",
                                dbName));
                            saveTaskInfo(validReq, taskId, TaskStatusTypeEnum.FAILED);
                            resultList.add(sqlResult);
                            sqlConsoleSQLResult.setResultsList(resultList);
                            return sqlConsoleSQLResult;
                        }
                        resultList.add(sqlResult);
                        saveTaskInfo(validReq, taskId, TaskStatusTypeEnum.FINISHED);
                        continue;
                    }

                    // create table tbName as select ,create table IF NOT EXISTS tbName as
                    // select or create external table tbName as select, create external table IF NOT EXISTS
                    // tbName as select
                    if (ArrayUtils.contains(sqlTokens, "select")) {
                        // to spark
                        String sparkTaskId = sparkTaskService.acceptTask(validReq);
                        sqlResult.setTaskID(sparkTaskId);
                        resultList.add(sqlResult);
                        String[] tblFullName = getTable(validReq);
                        tbModifyLog(sparkTaskId, Objects.requireNonNull(tblFullName)[0],
                            validReq.getSubmitUser(), Objects.requireNonNull(tblFullName)[1]);
                        continue;
                    }
                    // to hive
                    sqlResult = acceptHiveSQLTask(validReq, connection);
                    resultList.add(sqlResult);
                    continue;
                } else if (sqlTokens[0].equals(SQLTypeEnum.DROP.getStatus())) {
                    if ("schema".equals(sqlTokens[1]) || "database".equals(sqlTokens[1])) {
                        // drop database, drop schema. to operator
                        String dbName;
                        // "drop database if exists X" puts the name at token 4.
                        if (ArrayUtils.contains(sqlTokens, "if")) {
                            dbName = sqlTokens[4];
                        } else {
                            dbName = sqlTokens[2];
                        }
                        LakehouseDBRequest dbRequest = LakehouseDBRequest.builder().databaseName(dbName)
                            .userIdentifier(validReq.getUserId()).userName(validReq.getSubmitUser())
                            .build();
                        saveTaskInfo(validReq, taskId, TaskStatusTypeEnum.RUNNING);
                        sqlResult.setTaskID(taskId);
                        try {
                            dbOperationService.deleteDatabase(Arrays.asList(dbRequest));
                            dbModifyLog(dbRequest.getUserName(), dbName, dbRequest.getUserIdentifier());
                        } catch (Exception e) {
                            // Abort the whole batch on a failed drop-database.
                            sqlResult.setMessage(String.format(
                                "Error drop database %s! The database not exists or You do not have permission to drop the database",
                                dbName));
                            saveTaskInfo(validReq, taskId, TaskStatusTypeEnum.FAILED);
                            resultList.add(sqlResult);
                            sqlConsoleSQLResult.setResultsList(resultList);
                            return sqlConsoleSQLResult;
                        }
                        resultList.add(sqlResult);
                        saveTaskInfo(validReq, taskId, TaskStatusTypeEnum.FINISHED);
                        continue;
                    }

                    // to hive
                    sqlResult = acceptHiveSQLTask(validReq, connection);
                    resultList.add(sqlResult);
                    continue;
                } else if (sqlTokens[0].equals(SQLTypeEnum.INSERT.getStatus())
                    || sqlTokens[0].equals(SQLTypeEnum.SELECT.getStatus())
                    || sqlTokens[0].equals(SQLTypeEnum.WITH.getStatus())
                    || sqlTokens[0].equals(SQLTypeEnum.FROM.getStatus())
                    || sqlTokens[0].equals(SQLTypeEnum.MSCK.getStatus())
                    || sqlTokens[0].equals(SQLTypeEnum.ANALYZE.getStatus())
                    || sqlTokens[0].equals(SQLTypeEnum.TRUNCATE.getStatus())) {
                    // to spark
                    String sparkTaskId = sparkTaskService.acceptTask(validReq);
                    sqlResult.setTaskID(sparkTaskId);
                    resultList.add(sqlResult);
                    continue;
                } else if (sqlTokens[0].equals(SQLTypeEnum.SHOW.getStatus())
                    || sqlTokens[0].equals(SQLTypeEnum.ALTER.getStatus())
                    || sqlTokens[0].equals(SQLTypeEnum.DESCRIBE.getStatus())
                    || sqlTokens[0].equals(SQLTypeEnum.DESC.getStatus())) {
                    // to hive
                    sqlResult = acceptHiveSQLTask(validReq, connection);
                    resultList.add(sqlResult);
                    continue;
                }
            }
            sqlConsoleSQLResult.setResultsList(resultList);
        } catch (Throwable e) {
            log.error("Cannot execute " + req.getSqlContext(), e);
            throw new BaseException("Cannot execute sql: " + req.getSqlContext(), e);
        } finally {
            // Always release the shared Hive connection.
            if (connection != null) {
                try {
                    connection.close();
                } catch (SQLException e) {
                    /* ignored */
                }
            }
        }
        return sqlConsoleSQLResult;
    }

    /**
     * Saves an updated task entity, refreshing its optimistic-lock version from
     * the stored row first. Retries with backoff on failure; silently returns
     * when the task no longer exists.
     *
     * @param po task entity to persist
     */
    @Retryable(backoff = @Backoff(delay = 1000, multiplier = 1.1))
    public synchronized void updateTask(EngineTaskInfoEntity po) {
        List<EngineTaskInfoEntity> existing = sqlTaskDao.findAllByTaskId(po.getTaskId());
        if (!existing.isEmpty()) {
            // Adopt the current version so the optimistic-lock save succeeds.
            po.setVersion(existing.get(0).getVersion());
            sqlTaskDao.save(po);
        }
    }

    /**
     * Kills a Spark SQL task by id, handling each lifecycle state differently:
     * RUNNING/SUBMITTED tasks are killed via the K8S driver pod, already-ended
     * tasks are a no-op success, and anything else is cancelled locally.
     *
     * @param taskId identifier of the task to kill
     * @return engine response; NOT_FOUND failure when no such task exists
     */
    public LakehouseResponse killSparkTask(String taskId) {
        EngineTaskInfoEntity task = sqlTaskDao.findByTaskId(taskId);
        if (task == null) {
            return LakehouseResponse.fail(HttpStatus.NOT_FOUND.getCode(), Status.SQL_CONSOLE_TASK_NOT_FOUND_ERROR.getMessage(), false);
        }
        if (TaskStatusTypeEnum.RUNNING.equals(task.getStatusTypeEnum())
            || TaskStatusTypeEnum.SUBMITTED.equals(task.getStatusTypeEnum())) {
            // if running, use api to kill pod
            SQLTaskKillReq req = new SQLTaskKillReq();
            req.setMaster(KubernetesConfiguration.getK8sDefaultClusterInfo());
            req.setNamespace(ClusterUtil.convertCluster2Namespace(task.getInstance()));
            req.setDriverPodId(task.getDriverPodId());
            LakehouseResponse<Boolean> send = engineService.killSQLTaskOnK8s(req);
            if (send.getData()) {
                // Pod kill confirmed: mark CANCELLED and decrement the
                // instance's running-task counter in redis.
                // mark duration 0
                task.setFinishTime(new Date());
                task.setStatusTypeEnum(TaskStatusTypeEnum.CANCELLED);
                updateTask(task);
                String redisKey = redisOperateClient.generateRedisKey(task.getInstance());
                Map<String, Long> runningTaskChange = new HashMap<>();
                runningTaskChange.put(redisKey, -1L);
                redisOperateClient.updateRunningTask(runningTaskChange);
            }
            return send;
        } else if (TaskStatusTypeEnum.FINISHED.equals(task.getStatusTypeEnum())
            || TaskStatusTypeEnum.FAILED.equals(task.getStatusTypeEnum())
            || TaskStatusTypeEnum.CANCELLED.equals(task.getStatusTypeEnum())) {
            // not support for ended status, return true with nothing.
            return LakehouseResponse.success(true);
        } else {
            // Task never started (e.g. queued): cancel locally with zeroed
            // duration by stamping submit and finish with the same instant.
            task.setStatusTypeEnum(TaskStatusTypeEnum.CANCELLED);
            Date now = new Date();
            // mark duration 0
            task.setSubmitTime(now);
            task.setFinishTime(now);
            updateTask(task);
            return LakehouseResponse.success(true);
        }
    }

    /**
     * Kills a SQL task by id, dispatching on how it was submitted: Spark DDL
     * tasks are cancelled through their registered Hive JDBC statement, while
     * everything else is killed at the Spark driver pod on K8S.
     *
     * @param taskId identifier of the task to kill
     * @return success/failure response; NOT_FOUND failure when no such task
     */
    @Override
    public LakehouseResponse killSQLTask(String taskId) {
        EngineTaskInfoEntity taskInfoPO = sqlTaskDao.findByTaskId(taskId);
        if (taskInfoPO == null) {
            log.error("Task {} not found!", taskId);
            return LakehouseResponse.fail(HttpStatus.NOT_FOUND.getCode(), Status.SQL_CONSOLE_TASK_NOT_FOUND_ERROR.getMessage(), false);
        }

        // Spark and DDL SQL use hive jdbc cancel api.
        // Constant-first equals avoids an NPE when engineType is null.
        if (EngineType.SPARK.getType().equals(taskInfoPO.getEngineType())
            && "DDL".equals(taskInfoPO.getTaskType())) {
            // TODO: support presto and flink engine
            return jdbcKillTaskService.cancelJdbcSQLTask(taskId);
        }

        // Other SQL kill K8S POD (Spark)
        return killSparkTask(taskId);
    }

    /**
     * Records a database change in the metastore change log, updating the
     * existing entry when one is present and inserting otherwise.
     *
     * @param submitUser     user who performed the change
     * @param dbName         database that was created or dropped
     * @param userIdentifier identifier of the submitting user
     */
    private void dbModifyLog(String submitUser, String dbName, String userIdentifier) {
        boolean hasLog = CollectionUtils.isNotEmpty(metastoreConfigDao.findDbChangeLog(dbName));
        if (hasLog) {
            log.info("Update log for database {}", dbName);
            metastoreConfigDao.updateDbChangeLog(submitUser, dbName, userIdentifier);
        } else {
            log.info("Insert log for database :{}", dbName);
            metastoreConfigDao.insertDbChangeLog(submitUser, dbName, userIdentifier);
        }
    }

    /**
     * Extracts the {tableName, dbName} pair targeted by a CREATE or ALTER
     * statement.
     *
     * <p>Supports plain, db-qualified and catalog-qualified table references;
     * for an unqualified name the request's database is used.
     *
     * @param req SQL request whose sqlContext is inspected
     * @return two-element array {tableName, dbName}, or {@code null} when the
     *         statement is not CREATE/ALTER or the table cannot be determined
     */
    public String[] getTable(ExecuteSQLBean req) {

        String[] sqlTokens = req.getSqlContext().split("\\s+");
        String[] tb = null;
        try {
            if ("create".equalsIgnoreCase(sqlTokens[0])) {
                tb = SqlAnalyzer.getDatabaseAndTables(req.getSqlContext()).get(0).split("\\.");
            } else if ("alter".equalsIgnoreCase(sqlTokens[0])) {
                // "alter table t rename to x": the new name follows "to", so
                // token 5 must exist (original guard of > 4 could overflow).
                if (sqlTokens.length > 5 && "rename".equalsIgnoreCase(sqlTokens[3])
                    && "to".equalsIgnoreCase(sqlTokens[4])) {
                    tb = sqlTokens[5].split("\\.");
                } else {
                    tb = sqlTokens[2].split("\\.");
                }
            } else {
                return null;
            }
        } catch (Exception e) {
            // Keep the cause in the log instead of dropping it.
            log.error("Failed to get table info", e);
        }
        // Parsing failed above: report "no table" rather than NPE on tb below.
        if (tb == null) {
            return null;
        }
        if (tb.length == 3) {
            // catalog.db.table
            String tbName = tb[2];
            String dbName = tb[1];
            return new String[] {tbName, dbName};
        }

        // db.table, or a bare table name resolved against the request's db.
        String tbName = tb.length == 2 ? tb[1] : tb[0];
        String dbName = tb.length == 2 ? tb[0] : req.getDbName();

        return new String[] {tbName, dbName};
    }

    /**
     * Inserts or refreshes the metastore change-log entry for a table.
     *
     * @param taskId     id of the task that modified the table
     * @param tbName     table name
     * @param submitUser user who submitted the modifying statement
     * @param dbName     database containing the table
     */
    void tbModifyLog(String taskId, String tbName, String submitUser, String dbName) {
        if (CollectionUtils.isEmpty(metastoreConfigDao.findChangeLog(dbName, tbName))) {
            log.info("Insert cluster: {} table: {} modification log", dbName, tbName);
            metastoreConfigDao.insertChangeLog(taskId, tbName, submitUser, dbName);
            return;
        }
        // Fixed log typo ("Updata" -> "Update").
        log.info("Update cluster: {} table: {} modification log", dbName, tbName);
        metastoreConfigDao.updateChangeLog(taskId, tbName, submitUser, dbName);
    }

    /**
     * Best-effort recording of the table touched by {@code req} into the
     * metastore change log; failures are logged and never propagated.
     *
     * @param req    executed SQL request
     * @param taskId id of the task that ran the statement
     */
    void updateTblInfo(ExecuteSQLBean req, String taskId) {
        String[] tb = null;
        try {
            // getTable may throw on malformed SQL; keep it inside the guard so
            // this method can never break the caller's success path.
            tb = getTable(req);
            if (tb != null) {
                tbModifyLog(taskId, tb[0], req.getSubmitUser(), tb[1]);
            }
        } catch (Exception e) {
            // Pass the exception itself so the stack trace is preserved; guard
            // tb dereferences so the handler cannot NPE.
            log.error("insert or update cluster: {} table: {} modification log failed!",
                tb == null ? null : tb[1], tb == null ? null : tb[0], e);
        }
    }
}
